{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5b93ac521063c6a7",
   "metadata": {},
   "source": [
    "### 基于灰狼优化的遗传编程\n",
    "**GP入门系列教程地址：https://github.com/hengzhe-zhang/DEAP-GP-Tutorial**\n",
    "\n",
    "**前言：本文章旨在帮助研究灰狼算法的同学了解遗传编程，以及如何将灰狼算法应用到遗传编程中。文章不代表课题组学术观点。**\n",
    "\n",
    "**灰狼优化**和**灰狼**的关系就和**蚂蚁上树**与**蚂蚁**的关系是一样的。灰狼优化里面当然没有灰狼，正如蚂蚁上树里面也不会真的有蚂蚁一样。\n",
    "\n",
    "所谓灰狼优化，即Seyedali Mirjalili观察到灰狼种群分为alpha, beta, delta和omega狼，**alpha, beta, delta会带领omega狼**，从而设计的一种优化算法。\n",
    "\n",
    "灰狼算法现在有14000+的引用量，应该说还算是一个比较有影响力的算法。\n",
    "\n",
    "![灰狼优化](img/greywolfga.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c38417f7f5bd3ba6",
   "metadata": {},
   "source": [
    "### 实验问题\n",
    "\n",
    "本文的实验问题是GP领域最经典的符号回归问题，即根据训练数据，找到真实函数。\n",
    "\n",
    "在这里，我们的真实函数是$x^3 + x^2$。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 545,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.462077Z",
     "start_time": "2024-02-25T02:53:29.350856400Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import operator\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "from deap import base, creator, tools, gp\n",
    "from deap.tools import selTournament\n",
    "\n",
    "np.random.seed(0)\n",
    "random.seed(0)\n",
    "\n",
    "\n",
    "# 符号回归\n",
    "def evalSymbReg(individual, pset):\n",
    "    # 编译GP树为函数\n",
    "    func = gp.compile(expr=individual, pset=pset)\n",
    "    # 计算均方误差（Mean Square Error，MSE）\n",
    "    mse = ((func(x) - (x ** 3 + x ** 2)) ** 2 for x in range(-10, 10))\n",
    "    return (math.fsum(mse),)\n",
    "\n",
    "\n",
    "# 创建个体和适应度函数\n",
    "creator.create(\"FitnessMin\", base.Fitness, weights=(-1.0,))\n",
    "creator.create(\"Individual\", gp.PrimitiveTree, fitness=creator.FitnessMin)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de05d4716285bc12",
   "metadata": {},
   "source": [
    "#### 选择算子\n",
    "经典的灰狼算法主要是用于优化连续优化问题，对于遗传编程，我们可以基于遗传编程算法的特点，稍加修改。\n",
    "\n",
    "在这里，我们将Top-3的个体作为alpha, beta, delta，剩下的个体作为omega。\n",
    "\n",
    "然后，我们随机选择alpha, beta, delta中的一个个体，或者omega中的一个个体，作为新一代的个体。\n",
    "\n",
    "这里，由于选择alpha, beta, delta的概率是0.5，因此相当于整个种群会被alpha, beta, delta个体引领。这也就是灰狼算法最核心的思想。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 546,
   "id": "8d45279c64374b2e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.464592700Z",
     "start_time": "2024-02-25T02:53:29.413165300Z"
    }
   },
   "outputs": [],
   "source": [
    "from operator import attrgetter\n",
    "\n",
    "\n",
    "def selGWO(individuals, k, fit_attr=\"fitness\"):\n",
    "    # 根据适应度对个体进行排序；最优个体排在前面\n",
    "    sorted_individuals = sorted(individuals, key=attrgetter(fit_attr), reverse=True)\n",
    "\n",
    "    # 确定Top-3个体（alpha, beta, delta）\n",
    "    leaders = sorted_individuals[:3]\n",
    "\n",
    "    # 剩余的个体被视为omega\n",
    "    omega = sorted_individuals[3:]\n",
    "\n",
    "    # 选择交叉/变异的个体\n",
    "    return [random.choice(leaders) if random.random() < 0.5 else random.choice(omega) for _ in range(k)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 547,
   "id": "a64c3b6263b43be3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.480623900Z",
     "start_time": "2024-02-25T02:53:29.469592400Z"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# 定义函数集合和终端集合\n",
    "pset = gp.PrimitiveSet(\"MAIN\", arity=1)\n",
    "pset.addPrimitive(operator.add, 2)\n",
    "pset.addPrimitive(operator.sub, 2)\n",
    "pset.addPrimitive(operator.mul, 2)\n",
    "pset.addPrimitive(operator.neg, 1)\n",
    "pset.addEphemeralConstant(\"rand101\", lambda: random.randint(-1, 1))\n",
    "pset.renameArguments(ARG0='x')\n",
    "\n",
    "# 定义遗传编程操作\n",
    "toolbox = base.Toolbox()\n",
    "toolbox.register(\"expr\", gp.genHalfAndHalf, pset=pset, min_=0, max_=6)\n",
    "toolbox.register(\"individual\", tools.initIterate, creator.Individual, toolbox.expr)\n",
    "toolbox.register(\"population\", tools.initRepeat, list, toolbox.individual)\n",
    "toolbox.register(\"compile\", gp.compile, pset=pset)\n",
    "toolbox.register(\"evaluate\", evalSymbReg, pset=pset)\n",
    "toolbox.register(\"select\", selGWO)\n",
    "toolbox.register(\"mate\", gp.cxOnePoint)\n",
    "toolbox.register(\"mutate\", gp.mutUniform, expr=toolbox.expr, pset=pset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5938c2ff4b2308",
   "metadata": {},
   "source": [
    "### 实际结果\n",
    "\n",
    "现在，可以运行一下，看看实际的结果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 548,
   "id": "8c72542ecc15cf3c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.612186600Z",
     "start_time": "2024-02-25T02:53:29.484441600Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   \t      \t                              fitness                              \t                      size                     \n",
      "   \t      \t-------------------------------------------------------------------\t-----------------------------------------------\n",
      "gen\tnevals\tavg        \tgen\tmax        \tmin\tnevals\tstd        \tavg  \tgen\tmax\tmin\tnevals\tstd    \n",
      "0  \t100   \t1.85626e+09\t0  \t1.63558e+11\t20 \t100   \t1.63417e+10\t13.01\t0  \t109\t1  \t100   \t18.8199\n",
      "1  \t97    \t2.78507e+11\t1  \t2.77639e+13\t20 \t97    \t2.7624e+12 \t16.59\t1  \t84 \t1  \t97    \t14.7764\n",
      "2  \t91    \t1.45053e+10\t2  \t1.44943e+12\t0  \t91    \t1.44215e+11\t18.93\t2  \t66 \t1  \t91    \t8.99139\n",
      "3  \t92    \t1.05095e+14\t3  \t1.0508e+16 \t0  \t92    \t1.04554e+15\t22.02\t3  \t115\t1  \t92    \t13.124 \n",
      "4  \t90    \t3.5126e+08 \t4  \t1.76023e+10\t0  \t90    \t2.43413e+09\t21.97\t4  \t43 \t1  \t90    \t7.74914\n",
      "5  \t88    \t3.5427e+08 \t5  \t1.76023e+10\t0  \t88    \t2.40632e+09\t24.41\t5  \t43 \t6  \t88    \t6.87036\n",
      "6  \t91    \t3.62014e+08\t6  \t1.76022e+10\t0  \t91    \t2.4341e+09 \t26.1 \t6  \t98 \t4  \t91    \t10.9968\n",
      "7  \t84    \t2.02267e+12\t7  \t2.02136e+14\t0  \t84    \t2.01122e+13\t25.43\t7  \t72 \t3  \t84    \t12.1254\n",
      "8  \t94    \t6.17398e+08\t8  \t2.30114e+10\t0  \t94    \t3.40487e+09\t25.77\t8  \t74 \t10 \t94    \t11.0931\n",
      "9  \t86    \t6.03678e+08\t9  \t2.15822e+10\t0  \t86    \t3.36656e+09\t22.36\t9  \t76 \t3  \t86    \t12.3114\n",
      "10 \t82    \t7.0262e+08 \t10 \t1.76331e+10\t0  \t82    \t3.36997e+09\t25.97\t10 \t67 \t3  \t82    \t13.3045\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "from deap import algorithms\n",
    "\n",
    "# 定义统计指标\n",
    "stats_fit = tools.Statistics(lambda ind: ind.fitness.values)\n",
    "stats_size = tools.Statistics(len)\n",
    "mstats = tools.MultiStatistics(fitness=stats_fit, size=stats_size)\n",
    "mstats.register(\"avg\", numpy.mean)\n",
    "mstats.register(\"std\", numpy.std)\n",
    "mstats.register(\"min\", numpy.min)\n",
    "mstats.register(\"max\", numpy.max)\n",
    "\n",
    "# 使用默认算法\n",
    "population = toolbox.population(n=100)\n",
    "hof = tools.HallOfFame(1)\n",
    "_ = algorithms.eaSimple(population=population,\n",
    "                        toolbox=toolbox, cxpb=0.9, mutpb=0.1, ngen=10, stats=mstats, halloffame=hof,\n",
    "                        verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 549,
   "id": "2fec33b8c50c9e16",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.614209500Z",
     "start_time": "2024-02-25T02:53:29.611124100Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "add(add(neg(x), mul(mul(x, x), sub(x, -1))), neg(mul(add(0, x), neg(1))))\n"
     ]
    }
   ],
   "source": [
    "print(str(hof[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 550,
   "id": "5e9a8742f4621f52",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.830372600Z",
     "start_time": "2024-02-25T02:53:29.616722700Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   \t      \t                              fitness                              \t                      size                     \n",
      "   \t      \t-------------------------------------------------------------------\t-----------------------------------------------\n",
      "gen\tnevals\tavg        \tgen\tmax       \tmin   \tnevals\tstd        \tavg  \tgen\tmax\tmin\tnevals\tstd    \n",
      "0  \t100   \t5.51171e+12\t0  \t5.5117e+14\t154094\t100   \t5.48407e+13\t12.95\t0  \t100\t1  \t100   \t20.4931\n",
      "1  \t95    \t2.67999e+06\t1  \t3.36682e+06\t154094\t95    \t426386     \t7.25 \t1  \t59 \t1  \t95    \t12.0361\n",
      "2  \t91    \t2.65683e+06\t2  \t1.11493e+07\t154094\t91    \t1.01201e+06\t9.33 \t2  \t62 \t1  \t91    \t14.128 \n",
      "3  \t88    \t2.64718e+06\t3  \t2.828e+07  \t40666 \t88    \t2.87378e+06\t17.97\t3  \t100\t1  \t88    \t21.0701\n",
      "4  \t91    \t2.52386e+10\t4  \t2.40504e+12\t670   \t91    \t2.39326e+11\t33.98\t4  \t104\t1  \t91    \t25.6355\n",
      "5  \t95    \t4.33015e+06\t5  \t2.3502e+08 \t670   \t95    \t2.3649e+07 \t52.02\t5  \t142\t3  \t95    \t25.333 \n",
      "6  \t86    \t1.46688e+12\t6  \t1.46688e+14\t20    \t86    \t1.45952e+13\t57.23\t6  \t114\t7  \t86    \t22.9215\n",
      "7  \t95    \t5.50361e+06\t7  \t2.18106e+08\t20    \t95    \t3.01262e+07\t59.23\t7  \t133\t16 \t95    \t23.1477\n",
      "8  \t86    \t3.7073e+08 \t8  \t1.80501e+10\t20    \t86    \t2.49479e+09\t57.93\t8  \t133\t5  \t86    \t24.5647\n",
      "9  \t98    \t5.09515e+06\t9  \t2.23621e+08\t0     \t98    \t3.09103e+07\t61.48\t9  \t140\t18 \t98    \t29.1806\n",
      "10 \t89    \t5.50028e+08\t10 \t1.85126e+10\t0     \t89    \t3.1095e+09 \t64.09\t10 \t149\t11 \t89    \t31.7535\n"
     ]
    }
   ],
   "source": [
    "toolbox.register(\"select\", selTournament, tournsize=3)\n",
    "population = toolbox.population(n=100)\n",
    "hof = tools.HallOfFame(1)\n",
    "_ = algorithms.eaSimple(population=population,\n",
    "                        toolbox=toolbox, cxpb=0.9, mutpb=0.1, ngen=10, stats=mstats, halloffame=hof,\n",
    "                        verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 551,
   "id": "b66a9c085b6ac8f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-02-25T02:53:29.832024900Z",
     "start_time": "2024-02-25T02:53:29.829866900Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sub(neg(sub(sub(neg(add(1, add(1, x))), mul(add(1, x), mul(x, x))), mul(neg(neg(x)), -1))), sub(sub(neg(sub(add(-1, x), mul(1, x))), -1), neg(neg(0))))\n"
     ]
    }
   ],
   "source": [
    "print(str(hof[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e6257cae5d6e632",
   "metadata": {},
   "source": [
    "从结果可以看出，灰狼优化和传统的Tournament算子都可以成功地找到真实函数。相比之下，灰狼优化可以在更少的迭代次数内找到真实函数。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
